
include("../src/qmcmary.jl")
using ..qmcmary

using Test
using LinearAlgebra



"""
    sgn_scratch(ss::ScrollSVD)

Return a log-magnitude assembled from the factorizations stored in `ss`.

Two factor triples are selected from `ss.F` according to the first `true`
entry of `ss.L`:

- no `true` entry: the left triple is the identity, the right one is `ss.F[end]`;
- first `true` at index 1: the left triple is `ss.F[1]`, the right one is the identity;
- otherwise: the left triple is `ss.F[ptr]`, the right one is `ss.F[ptr-1]`.

The diagonal scales of each triple are split into a part holding the entries
larger than one and a part holding the remaining entries, a mixed matrix is
built from them and decomposed with an SVD, and the function returns

    log|det(Fm.Vt*UL) * det(UR*Fm.U)| + sum(log, Fm.S)
        + sum(log, diag(DLB)) + sum(log, diag(DRB))

If `UL` and `UR` are unitary this equals `log|det(I + UR*DR*VR*VL*DL*UL)|`.
Despite the function name, only the magnitude is returned, not a sign or phase.
"""
function sgn_scratch(ss::ScrollSVD{T}) where T
    n = size(ss.B[end])[1]
    eye = Diagonal(ones(n))
    # Unpack one stored factorization as (U, Diagonal(S), Vt).
    factors(F) = (F.U, Diagonal(F.S), F.Vt)

    ptr = findfirst(ss.L)
    if isnothing(ptr)
        VL, DL, UL = eye, eye, eye
        UR, DR, VR = factors(ss.F[end])
    elseif ptr == 1
        # NOTE(review): the original author remarked that this arrangement can
        # suddenly become numerically unstable and suggested the mirrored one
        # instead (identity on the left, ss.F[1] on the right). The mirrored
        # variant was left disabled, so the arrangement below is what runs.
        VL, DL, UL = factors(ss.F[1])
        UR, DR, VR = eye, eye, eye
    else
        VL, DL, UL = factors(ss.F[ptr])
        UR, DR, VR = factors(ss.F[ptr-1])
    end

    # Split a diagonal into (entries <= 1, entries > 1); the slot that does not
    # receive the entry keeps the value one.
    function split_scales(D)
        below = Diagonal(ones(Float64, n))
        above = Diagonal(ones(Float64, n))
        for i in Base.OneTo(n)
            d = D[i, i]
            if d > 1.0
                above[i, i] = d
            else
                below[i, i] = d
            end
        end
        return below, above
    end
    DLS, DLB = split_scales(DL)
    DRS, DRB = split_scales(DR)

    # Large scales enter only through their inverses, small scales directly.
    M = inv(DRB)*adjoint(UL*UR)*inv(DLB) + DRS*(VR*VL)*DLS
    Fm = svd(M, alg=LinearAlgebra.QRIteration())

    # Accumulate the log-magnitude: outer factors first, then the singular
    # values of M, then the large scales that were divided out.
    logmag = log(abs(det(Fm.Vt*UL)*det(UR*Fm.U)))
    for scales in (Fm.S, diag(DLB), diag(DRB))
        for s in scales
            logmag += log(s)
        end
    end
    return logmag
end


"""
    run(L)

Exploratory driver that prints intermediate quantities for a kagome setup
built from `L` (with `Nx = 3*L^2`, `Nt = 60`, `Ng = 4`).

It draws a random field configuration with entries ±1, initializes the
propagator stack, prints slices of the per-site derivative data returned by
`pbpn_calc_xi` and of the tape outputs from `pbpn_calc_tapemap`, then prints
`S1 - S2` (the change of `sgn_scratch` when `nx[6]` is increased by `0.01`)
followed by `nxbar[6]` from `meas_grad`. Presumably this is a manual
finite-difference check of `meas_grad`; nothing is asserted or returned.
"""
function run(L)
    Nx = 3*L^2
    Nt, Ng = 60, 4
    hk = lattice_kagome(ComplexF64, L, -1.0+0.0im)
    Ui = fill(9.0, Nx)

    # Uniform direction field with both angles equal to π/4.
    psi = fill(π*0.25, Nx)
    the = fill(π*0.25, Nx)
    nx = @. sin(psi) * sin(the)
    ny = @. cos(psi) * sin(the)
    nz = cos.(the)

    # Random configuration with entries ±1, one row per slice.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for row in eachrow(hscfg)
        row .= 2 .* (round.(rand(Nx)) .- 0.5)
    end

    sp = default_splitting2(Nt, hk, Ui)
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    xidx = 1
    pxr, _, _, _, _, _ = pbpn_calc_xi(
        sp, hscfg[:, xidx], xidx, nx[xidx], ny[xidx], nz[xidx]
    )
    tapemap = pbpn_calc_tapemap(sp, nx, ny, nz)
    println(hscfg[1:5, xidx])
    println(pxr[1:5, :, 1:5])

    # Buffers for the tape output: rows 1:2 go to the first buffer of each
    # pair, rows 3:4 to the second, one pair per x/y/z component.
    pbrpnx2 = zeros(Nt, 2, 2*Nx)
    pbipnx2 = zeros(Nt, 2, 2*Nx)
    pbrpny2 = zeros(Nt, 2, 2*Nx)
    pbipny2 = zeros(Nt, 2, 2*Nx)
    pbrpnz2 = zeros(Nt, 2, 2*Nx)
    pbipnz2 = zeros(Nt, 2, 2*Nx)
    targets = ((pbrpnx2, pbipnx2), (pbrpny2, pbipny2), (pbrpnz2, pbipnz2))
    for bi in 1:5
        println(sp.vecl[bi])
        # Even tape index for a +1 field value, odd index otherwise.
        if hscfg[bi, xidx] == 1
            typeidx = 2*sp.vecl[bi]
        else
            typeidx = 2*sp.vecl[bi] - 1
        end
        jac = reshape(tapemap[xidx, typeidx].output, 4, 2*Nx, 3)
        for (k, (first_half, second_half)) in enumerate(targets)
            first_half[bi, :, :] = jac[1:2, :, k]
            second_half[bi, :, :] = jac[3:4, :, k]
        end
    end
    println(pbrpnx2[1:5, :, 1:5])

    nxbar, nybar, nzbar = meas_grad(
        ss, allbmats, sp, hscfg, nx, ny, nz; tapemap=tapemap
    )
    S1 = sgn_scratch(ss)

    # Perturb a single component and rebuild everything from scratch.
    nx2, ny2, nz2 = copy(nx), copy(ny), copy(nz)
    nx2[6] += 0.01
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx2, ny2, nz2)
    S2 = sgn_scratch(ss)
    println(-S2 + S1)
    println(nxbar[6])
end

#run(2)


"""
    compare_sgn(L)

Print, for 16 independent random field configurations, the second return
value of `eq_green_scratch` and the value of `sgn_scratch` obtained with two
different splittings (`default_splitting` and `default_splitting3`) of the
same setup (`Nx = 3*L^2`, `Nt = 60`, `Ng = 4`).

Each trial prints two lines: first the pair from `eq_green_scratch`, then the
pair from `sgn_scratch`. Nothing is asserted or returned.
"""
function compare_sgn(L)
    Nx = 3*L^2
    Nt, Ng = 60, 4
    hk = lattice_kagome(ComplexF64, L, -1.0+0.0im)
    Ui = fill(6.0, Nx)

    # Uniform direction field with both angles equal to π/4.
    psi = fill(π*0.25, Nx)
    the = fill(π*0.25, Nx)
    nx = @. sin(psi) * sin(the)
    ny = @. cos(psi) * sin(the)
    nz = cos.(the)

    hscfg = Matrix{Int}(undef, Nt, Nx)
    for _ in 1:16
        # Fresh random configuration with entries ±1 for every trial.
        for row in eachrow(hscfg)
            row .= 2 .* (round.(rand(Nx)) .- 0.5)
        end

        sp = default_splitting(Nt, hk, Ui)
        _, _, _, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
        sp2 = default_splitting3(Nt, hk, Ui)
        _, _, _, ss2 = initialize_SS(Nt, Ng, sp2, hscfg, nx, ny, nz)

        _, ph = eq_green_scratch(ss)
        _, ph2 = eq_green_scratch(ss2)
        println("$ph $ph2")

        val1 = sgn_scratch(ss)
        val2 = sgn_scratch(ss2)
        println("$val1 $val2")
    end
end

# Script entry point: run the splitting comparison with L = 2 (Nx = 3*L^2 = 12).
compare_sgn(2)
